Compare Cell Entry Effects¶

In this notebook, we'll investigate whether the same mutation affects cell entry differently across three cell types (293T-Mxra8, 293T-TIM1, and C6/36).

Although Mxra8 serves as a receptor and TIM1 as an entry factor for CHIKV in humans, the mosquito receptor remains unknown. By identifying sites where mutations affect cell entry differently in mosquito cells (C6/36) than in human cells (293T-Mxra8 and 293T-TIM1), we may uncover sites involved in binding to the unknown mosquito receptor.

In [1]:
import itertools
import os

import altair as alt

import dmslogo.colorschemes

import numpy

import pandas as pd

import polyclonal.alphabets

import scipy.spatial.distance

# Remove the limit of ~5000 rows -- maybe there are better ways? (https://altair-viz.github.io/user_guide/large_datasets.html)
_ = alt.data_transformers.disable_max_rows()

Get the input parameters¶

The notebook is designed to be parameterized by papermill. The next cell is tagged parameters:

In [2]:
# This cell is tagged parameters, so the values defined here will be overwritten
# by the `papermill` parameterization.

# CSV with filtered data
mut_effects_csv = "../results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"

# cells and their names in input file
cells = {"C6/36": "C636", "293T-Mxra8": "293T_Mxra8", "293T-TIM1": "293T_TIM1"}

# for calculating differences and display, floor mutation effects at this
floor_mut_effects = -5

# output files
site_diffs_csv = "../results/compare_cell_entry/site_diffs.csv"
mut_scatter_chart = "../results/compare_cell_entry/compare_cell_entry_scatter.html"
site_zoom_chart = "../results/compare_cell_entry/compare_cell_entry_site_zoom.html"
In [3]:
# Parameters
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636", "293T-TIM1": "293T_TIM1"}
floor_mut_effects = -5
mut_effects_csv = "results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"
site_diffs_csv = "results/compare_cell_entry/site_diffs.csv"
mut_scatter_chart = "results/compare_cell_entry/compare_cell_entry_scatter.html"
site_zoom_chart = "results/compare_cell_entry/compare_cell_entry_site_zoom.html"

Read the data¶

For this analysis, we'll need the effects of mutations on cell entry in each cell line.

These values have already been filtered on QC metrics:

In [4]:
print(f"Reading mutation effects from {mut_effects_csv=}")
mut_effects = pd.read_csv(mut_effects_csv)

mut_effects
Reading mutation effects from mut_effects_csv='results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv'
Out[4]:
       site    wildtype  mutant  entry in 293T_Mxra8 cells  entry in C636 cells  entry in 293T_TIM1 cells  binding to mouse Mxra8  sequential_site  region
0      -1(E3)  M         I       -7.5410                    -7.514               -7.50200                  NaN                     1                E3
1      -1(E3)  M         M       0.0000                     0.000                0.00000                   0.00000                 1                E3
2      -1(E3)  M         T       -7.5630                    -7.541               -7.57600                  NaN                     1                E3
3      1(6K)   A         A       0.0000                     0.000                0.00000                   0.00000                 489              6K
4      1(6K)   A         C       0.1786                     0.035                0.02934                   -0.02603                489              6K
...    ...     ...       ...     ...                        ...                  ...                       ...                     ...              ...
19453  99(E2)  H         S       -7.2690                    -7.132               -6.60600                  NaN                     164              E2
19454  99(E2)  H         T       -7.4930                    -6.834               -6.99100                  NaN                     164              E2
19455  99(E2)  H         V       -7.5370                    -7.494               -7.41200                  NaN                     164              E2
19456  99(E2)  H         W       -7.0080                    -6.427               -5.61900                  NaN                     164              E2
19457  99(E2)  H         Y       -3.1680                    -1.463               -1.22800                  -0.41320                164              E2

19458 rows × 9 columns

Get the data into tidy format:

In [5]:
col_to_cell = {f"entry in {label} cells": cell for (cell, label) in cells.items()}

assert set(col_to_cell).issubset(mut_effects.columns), f"{col_to_cell=}, {mut_effects.columns=}"

mut_effects_tidy = (
    mut_effects.rename(columns=col_to_cell)
    .melt(
        id_vars=["site", "sequential_site", "wildtype", "mutant", "region"],
        value_vars=col_to_cell.values(),
        var_name="cell",
        value_name="effect",
    )
    .sort_values("sequential_site")
)

mut_effects_tidy
Out[5]:
       site     sequential_site  wildtype  mutant  region  cell        effect
0      -1(E3)   1                M         I       E3      293T-Mxra8  -7.5410
38916  -1(E3)   1                M         I       E3      293T-TIM1   -7.5020
38917  -1(E3)   1                M         M       E3      293T-TIM1   0.0000
19458  -1(E3)   1                M         I       E3      C6/36       -7.5140
19459  -1(E3)   1                M         M       E3      C6/36       0.0000
...    ...      ...              ...       ...     ...     ...         ...
35495  439(E1)  988              H         F       E1      C6/36       -0.7510
35496  439(E1)  988              H         G       E1      C6/36       -0.5489
35510  439(E1)  988              H         Y       E1      C6/36       -0.5264
16040  439(E1)  988              H         I       E1      293T-Mxra8  -0.7279
16041  439(E1)  988              H         K       E1      293T-Mxra8  0.0487

58374 rows × 7 columns
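
As a minimal illustration of this wide-to-long reshaping, the sketch below reproduces the same idea in plain Python on two rows taken from the wide table above (only two of the cell columns are kept for brevity). The helper name `melt_rows` is hypothetical and is not part of the notebook, which uses `DataFrame.rename` followed by `DataFrame.melt`.

```python
# Two rows from the wide table above, keeping two of the cell columns.
wide_rows = [
    {"site": "1(6K)", "wildtype": "A", "mutant": "C",
     "entry in 293T_Mxra8 cells": 0.1786, "entry in C636 cells": 0.035},
    {"site": "99(E2)", "wildtype": "H", "mutant": "Y",
     "entry in 293T_Mxra8 cells": -3.168, "entry in C636 cells": -1.463},
]

# Map display names to the labels used in the column headers.
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636"}
col_to_cell = {f"entry in {label} cells": cell for cell, label in cells.items()}


def melt_rows(rows, id_vars, col_to_cell):
    """Hypothetical helper: emit one record per (input row, cell column)."""
    tidy = []
    for col, cell in col_to_cell.items():
        for row in rows:
            record = {key: row[key] for key in id_vars}
            record["cell"] = cell
            record["effect"] = row[col]
            tidy.append(record)
    return tidy


tidy_rows = melt_rows(wide_rows, ["site", "wildtype", "mutant"], col_to_cell)
for record in tidy_rows:
    print(record)
```

Each wide row yields one tidy record per cell, which is why the tidy table has three times as many rows as the wide one when all three cells are included.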

Scatter plots of cell entry for each cell¶

How does the same mutation affect entry in each cell line? We'll plot the effect of each mutation between pairs of cell lines to determine if there are global differences.

In [6]:
def plot_mut_scatter_chart(
    data,
    condition,
    value,
    groupby=['site', 'mutant', 'wildtype', 'sequential_site'],
    color=None,
    label_suffix="",
    init_floor_value=-6,
):
    """
    Make an Altair scatter plot comparing mutant-level values for each condition.

    Parameters
    ----------
    data : pd.DataFrame
        The long-form data to plot
    condition : str
        The column containing the condition labels (i.e. TIM1, MXRA8, C636)
    value : str
        The column containing the values to compare between conditions
    groupby : list of str
        The columns to group the data on (i.e. ['site', 'mutant', 'wildtype'])
    color : str
        The column to color the points and add an interactive legend for
    label_suffix : str
        Label suffixed to x- and y-axis labels.
    init_floor_value : float or None
        Initial value for floor slider for values.

    Returns
    -------
    alt.Chart
        The Altair chart object
    """    
    if 'mutant' not in groupby or 'site' not in groupby:
        raise ValueError("groupby must contain 'mutant' and 'site'")
    
    missing_cols = [col for col in [condition, value] + groupby if col not in data.columns]
    if missing_cols:
        raise ValueError(f"Columns are missing from the data: {missing_cols}")
    
    if color is not None:
        if color not in data.columns:
            raise ValueError(f"Color column '{color}' not found in data")
        # build a new list so the mutable default argument is never modified
        groupby = [*groupby, color]
    
    conditions = data[condition].unique()

    # pivot the data
    data_wide = (
        data
        .pivot_table(index=groupby, columns=condition, values=value)
        .reset_index()
    )

    tooltips = []
    for col in groupby:
        tooltips.append(alt.Tooltip(f'{col}:N'))
    for col in conditions:
        tooltips.append(alt.Tooltip(f'{col}:Q', format=".2f"))

    brush = alt.selection_interval()
    
    mut_selection = alt.selection_point(on="mouseover", fields=groupby, empty=False)

    min_value_slider = alt.param(
        name="min_value_slider",
        bind=alt.binding_range(
            min=min(data[value]),
            max=max(data[value]),
            name="floor values at this number",
        ),
        value=(
            max(init_floor_value, min(data[value]))
            if init_floor_value is not None
            else min(data[value])
        ),
    )

    base = (
        alt.Chart(data_wide)
        .add_params(mut_selection, brush, min_value_slider)
        .transform_filter(brush)
    )

    scatters = []
    for condition_a, condition_b in itertools.combinations(conditions, 2):
        # Base data for the scatter plot
        scatter = base.transform_filter(
            f'isValid(datum["{condition_a}"]) && isValid(datum["{condition_b}"])'
        ).transform_calculate(
            condition_a_floored=f'max(datum["{condition_a}"], min_value_slider)',
            condition_b_floored=f'max(datum["{condition_b}"], min_value_slider)',
        ).encode(
            x=alt.X(
                "condition_a_floored:Q",
                title=condition_a + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
            y=alt.Y(
                "condition_b_floored:Q",
                title=condition_b + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
        ).properties(
            title=alt.TitleParams(f'{condition_a} vs {condition_b}', fontSize=16),
            width=250,
            height=250
        )
        # Background points to show the full range of data when brushing
        background = scatter.mark_point(
            filled=True,
            size=25,
            color='lightgray',
            opacity=0.3,
        )
        # Foreground points have tooltips and respond to brushing (and legend selection)
        if color is not None:
            selection = alt.selection_point(fields=[color], bind='legend')
            foreground = scatter.mark_point(
                filled=True,
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                color=alt.Color(color, type='nominal').scale(domain=data[color].unique()),
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(80), alt.value(40)),
                tooltip=tooltips,
            ).add_params(
                selection
            ).transform_filter(selection)
        else:
            foreground = scatter.mark_point(
                filled=True,
                color='steelblue',
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                tooltip=tooltips,
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
            )

        scatters.append((background + foreground))

    chart = alt.hconcat(*scatters).configure_axis(grid=False).configure_legend(
        titleFontSize=14, labelFontSize=14
    )

    return chart
In [7]:
mut_scatter = plot_mut_scatter_chart(
    mut_effects_tidy,
    "cell",
    "effect", 
    color="region",
    label_suffix=" cell entry",
    init_floor_value=floor_mut_effects,
)

print(f"Saving chart to {mut_scatter_chart=}")
os.makedirs(os.path.dirname(mut_scatter_chart), exist_ok=True)
mut_scatter.save(mut_scatter_chart)

mut_scatter
Saving chart to mut_scatter_chart='results/compare_cell_entry/compare_cell_entry_scatter.html'
Out[7]:
  • Mouseover on points to see a tooltip with information about that mutation.
  • Hold Click and Drag over points to show only those mutations.
  • Click on conditions in the legend to show only that condition (region).
  • Use the slider to floor values at some minimum plot value.
  • Double Click on the plot or legend to reset the plot.

Colored points show the active selection, and gray points show the total distribution of the data.

Identify sites where mutations have different effects in each cell¶

Compute site differences between conditions¶

We use three different site-level metrics for the differences between conditions:

  • mean difference: The mean, over all non-wildtype amino acids at a site, of the effect on cell entry in cell_1 minus the effect in cell_2. We compute this mean after flooring all cell entry effects at the value specified by floor_mut_effects.
  • Jensen-Shannon divergence: A "probability" is assigned to each amino acid at each site, proportional to exp(effect), and then the Jensen-Shannon divergence is computed between the probabilities for cell_1 and cell_2.
  • difference in constraint: A "probability" is assigned to each amino acid, proportional to exp(effect), the number of effective amino acids at each site is computed for each cell, and we report the number for cell_1 minus the number for cell_2.
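
To make the three definitions concrete, here is a hedged, pure-Python sketch for a single site. It assumes natural logarithms (so the Jensen-Shannon divergence is in nats), takes the number of effective amino acids to be the exponential of the Shannon entropy, and uses only amino acids measured in both cells. The function `site_metrics` and the toy effect values are hypothetical and are not part of the notebook.

```python
import math


def site_metrics(effects_1, effects_2, wildtype, floor=-5.0):
    """Return (mean difference, JSD, difference in n_effective) for one site.

    effects_1 and effects_2 map amino acid -> effect on cell entry.
    """
    shared = sorted(set(effects_1) & set(effects_2))

    # Mean difference over non-wildtype amino acids, after flooring.
    diffs = [
        max(effects_1[aa], floor) - max(effects_2[aa], floor)
        for aa in shared
        if aa != wildtype
    ]
    mean_diff = sum(diffs) / len(diffs) if diffs else float("nan")

    def probs(effects):
        # "Probabilities" proportional to exp(effect), normalized to sum to 1.
        weights = [math.exp(effects[aa]) for aa in shared]
        total = sum(weights)
        return [w / total for w in weights]

    def kl(p, q):
        # Kullback-Leibler divergence, skipping zero-probability terms.
        return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)

    def n_effective(p):
        # Exponential of the Shannon entropy.
        return math.exp(-sum(a * math.log(a) for a in p if a > 0))

    p1, p2 = probs(effects_1), probs(effects_2)
    mixture = [(a + b) / 2 for a, b in zip(p1, p2)]
    jsd = 0.5 * kl(p1, mixture) + 0.5 * kl(p2, mixture)
    return mean_diff, jsd, n_effective(p1) - n_effective(p2)


# Toy site: wildtype A; mutation C is tolerated in cell 1 but not in cell 2.
tolerant = {"A": 0.0, "C": 0.0}
intolerant = {"A": 0.0, "C": -10.0}
mean_diff, jsd, n_eff_diff = site_metrics(tolerant, intolerant, wildtype="A")
print(mean_diff, jsd, n_eff_diff)
```

In this toy case the floor caps the mean difference at 5 even though the raw effects differ by 10, and the site has close to two effective amino acids in the first cell but close to one in the second.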
In [8]:
# first get color to use for each amino-acid in scatter plot
# this also defines list of amino acids to keep
aa_color_df = (
    pd.Series(dmslogo.colorschemes.AA_FUNCTIONAL_GROUP)
    .rename_axis("mutant")
    .rename("color")
    .reset_index()
)
aas = polyclonal.alphabets.biochem_order_aas(polyclonal.alphabets.AAS)
assert set(aa_color_df["mutant"]) == set(aas)

# get mutation level data, just for amino acids
assert set(cells) == set(mut_effects_tidy["cell"])
mut_data = (
    mut_effects_tidy
    .query("mutant in @aas")
    .pivot_table(
        index=["site", "sequential_site", "wildtype", "mutant", "region"],
        columns="cell",
        values="effect",
    )
    .sort_values("sequential_site")
    .reset_index()
)
assert set(mut_data["wildtype"]).issubset(aas)

# get site difference data
def get_site_diffs(df):
    is_wildtype = df.iloc[:, 0]
    s1 = df.iloc[:, 1]
    s2 = df.iloc[:, 2]
    # simple mean difference across non-wildtype amino acids at the site
    mean_diff = (s1.clip(lower=floor_mut_effects) - s2.clip(lower=floor_mut_effects))[~is_wildtype].mean()
    # Jensen-Shannon divergence of probabilities proportional to exp(effect)
    p1 = numpy.exp(s1[s1.notnull() & s2.notnull()])
    p2 = numpy.exp(s2[s1.notnull() & s2.notnull()])
    assert len(p1) == len(p2)
    if len(p1):
        p1 /= p1.sum()
        p2 /= p2.sum()
        jsd = scipy.spatial.distance.jensenshannon(p1, p2)**2
    else:
        jsd = 0
    # difference in n_effective
    if len(p1) == 0:
        n_eff_diff = 0
    else:
        n_eff_1 = len(aas)**(-p1 * numpy.log(p1) / numpy.log(len(aas))).sum()
        n_eff_2 = len(aas)**(-p2 * numpy.log(p2) / numpy.log(len(aas))).sum()
        n_eff_diff = n_eff_1 - n_eff_2
    return pd.Series(
        {
            "mean difference": mean_diff,
            "Jensen-Shannon divergence": jsd,
            "difference in constraint": n_eff_diff,
        }
    )
    
site_diff_metrics = [
    "difference in constraint", "mean difference", "Jensen-Shannon divergence"
]
site_diffs = []
for cell_1, cell_2 in itertools.combinations(cells, 2):
    site_diffs.append(
        mut_data
        .assign(is_wildtype=lambda x: x["mutant"] == x["wildtype"])
        .groupby(["site", "sequential_site", "region"])
        [["is_wildtype", cell_1, cell_2]]
        .apply(get_site_diffs)
        .assign(cell_1=cell_1, cell_2=cell_2)
        .sort_values("sequential_site")
        .reset_index()
    )
site_diffs = pd.concat(site_diffs, ignore_index=True)
assert set(site_diff_metrics).issubset(site_diffs.columns)

print(f"For mean difference, effects floored at {floor_mut_effects=} first.")
print(f"Saving site differences to {site_diffs_csv=}")
site_diffs.to_csv(site_diffs_csv, index=False, float_format="%.3f")
site_diffs
For mean difference, effects floored at floor_mut_effects=-5 first.
Saving site differences to site_diffs_csv='results/compare_cell_entry/site_diffs.csv'
Out[8]:
      site     sequential_site  region  mean difference  Jensen-Shannon divergence  difference in constraint  cell_1      cell_2
0     -1(E3)   1                E3      0.000000         8.062812e-08               -0.000198                 293T-Mxra8  C6/36
1     1(E3)    2                E3      0.118259         1.126940e-02               0.854814                  293T-Mxra8  C6/36
2     2(E3)    3                E3      -0.061232        2.247468e-02               0.370147                  293T-Mxra8  C6/36
3     3(E3)    4                E3      -0.031816        5.355085e-03               0.956139                  293T-Mxra8  C6/36
4     4(E3)    5                E3      -0.092037        6.106870e-03               -0.343112                 293T-Mxra8  C6/36
...   ...      ...              ...     ...              ...                        ...                       ...         ...
2959  435(E1)  984              E1      0.222021         1.957161e-02               -0.247524                 C6/36       293T-TIM1
2960  436(E1)  985              E1      0.197230         1.346945e-02               0.012677                  C6/36       293T-TIM1
2961  437(E1)  986              E1      0.329445         1.677606e-02               0.344722                  C6/36       293T-TIM1
2962  438(E1)  987              E1      0.205787         3.940191e-02               0.395427                  C6/36       293T-TIM1
2963  439(E1)  988              E1      0.209640         1.852273e-02               0.094544                  C6/36       293T-TIM1

2964 rows × 8 columns
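
The `cell_1` and `cell_2` columns come from taking every unordered pair of cell names with `itertools.combinations`. The short sketch below assumes the three cell names in the order given in the injected parameters cell, and assumes every one of the 988 sites appears in each comparison; under those assumptions it recovers the row count reported above.

```python
import itertools

# Cell names in the order of the injected `cells` parameter.
cell_names = ["293T-Mxra8", "C6/36", "293T-TIM1"]

# Every unordered pair, in the order the comparisons are computed.
pairs = list(itertools.combinations(cell_names, 2))
for cell_1, cell_2 in pairs:
    print(f"{cell_1} minus {cell_2}")

n_sites = 988  # largest sequential_site in the table above
n_rows = len(pairs) * n_sites
print(n_rows)
```

Because each pair is computed in one direction only, the sign of the mean difference and of the difference in constraint always refers to cell_1 minus cell_2.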

Plot sites with large differences¶

We make an interactive plot that includes:

  • line plot with site differences at top left
  • scatter plot of mutation effects at top right
  • heatmaps centered around key site at bottom

You can click sites on the site plot to show them on the mutation-level plots, zoom with the zoom bar, and use the menus at the bottom to adjust other options, including which cells to compare.

In [9]:
def plot_site_comparison(
    mut_data,
    site_diffs,
    cells,
    site_diff_metrics,
    aas,
    aa_color_df,
    init_floor_effect,
    heatmap_max_at_least=2,
    heatmap_flank=12,
):
    """Plot site-level differences in entry effects between cells, with mutation-level zooms."""

    # some params
    site_chart_width = 700

    assert set(mut_data["site"]) == set(site_diffs["site"])
    assert set(site_diff_metrics).issubset(site_diffs.columns)

    # Drag to zoom into sites on the x-axis colored by region
    zoom_selection = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke='black', strokeWidth=2)
    )

    # zoom bar
    zoom_bar = (
        alt.Chart(mut_data[["site", "sequential_site", "region"]])
        .mark_rect()
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title="click and drag to zoom on sites",
                axis=alt.Axis(ticks=False, labels=False, titleFontWeight="normal"),
            ),
            alt.Color("region", scale=alt.Scale(scheme="greys"), legend=None),
            tooltip=["site", "sequential_site", "region"],
        )
        .properties(width=site_chart_width, height=10)
        .add_params(zoom_selection)
    )

    # line plot
    metric_selection = alt.selection_point(
        fields=["metric"],
        name="metric_selection",
        value=site_diff_metrics[0],
        bind=alt.binding_select(
            options=site_diff_metrics,
            name="metric for site differences between cells",
        ),
    )

    cell_1_options = [c for c in cells if c in set(site_diffs["cell_1"])]
    cell_1_selection = alt.param(
        name="cell_1",
        value=cell_1_options[0],
        bind=alt.binding_select(
            options=cell_1_options,
            name="comparator cell line",
        )
    )

    cell_2_options = [c for c in cells if c in set(site_diffs["cell_2"])]
    cell_2_selection = alt.param(
        name="cell_2",
        value=cell_2_options[0],
        bind=alt.binding_select(
            options=cell_2_options,
            name="reference cell line",
        )
    )

    # site w biggest effect
    default_site = (
        site_diffs[
            (site_diffs["cell_1"] == cell_1_options[0])
            & (site_diffs["cell_2"] == cell_2_options[0])
        ]
        .set_index("site")
        [site_diff_metrics[0]]
        .abs()
        .sort_values(ascending=False)
        .index[0]
    )
    default_sequential_site = site_diffs.set_index("site")["sequential_site"].to_dict()[default_site]

    site_selection = alt.selection_point(
        fields=["site"], empty=False, value=default_site, on="click"
    )
    sequential_site_selection = alt.selection_point(
        fields=["sequential_site"],
        empty=False,
        value=default_sequential_site,
        on="click",
    )
    
    site_base = (
        alt.Chart(site_diffs)
        .transform_filter(zoom_selection)
        .transform_filter(
            (alt.datum["cell_1"] == cell_1_selection)
            & (alt.datum["cell_2"] == cell_2_selection)
        )
        .transform_fold(
            site_diff_metrics,
            ["metric", "difference"],
        )
        .transform_filter(metric_selection)
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=alt.Axis(labelOverlap="greedy", ticks=False),
            ),
            alt.Y(
                "difference:Q",
                title="difference at site",
                scale=alt.Scale(nice=False, padding=9),
            ),
            tooltip=[
                "site", "sequential_site", "region", alt.Tooltip("difference:Q", format=".2f")
            ],
        )
    )
    
    site_lines = site_base.mark_line(color="black", strokeWidth=1, opacity=1)

    site_points = site_base.mark_circle(filled=True, fill="black", stroke="gold", opacity=1).encode(
        strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(0)),
        size=alt.condition(site_selection, alt.value(180), alt.value(60)),
    )

    # Dynamic title for chart plot
    site_title = alt.TitleParams(
        alt.expr(
            f'"difference between mutation effects in " + {cell_1_selection.name} + " versus " + {cell_2_selection.name} + " cells"'
        ),
        subtitle="click on a site to show in the mutation-level scatter plot and heatmaps",
        anchor="middle",
    )

    site_chart = (
        (site_lines + site_points)
        .properties(width=site_chart_width, height=185, title=site_title)
        .add_params(
            metric_selection, site_selection, sequential_site_selection, cell_1_selection, cell_2_selection,
        )
    )

    # amino-acid scatter plot for a single site
    min_effect = mut_data[list(cells)].min().min()
    max_effect = mut_data[list(cells)].max().max()
    min_effect_slider = alt.param(
        name="min_effect_slider",
        bind=alt.binding_range(
            min=min_effect, max=max_effect, name="floor displayed mutation effect at",
        ),
        value=max(init_floor_effect, min_effect) if init_floor_effect is not None else min_effect,
    )
    
    mut_base = alt.Chart(mut_data).add_params(min_effect_slider)

    mutant_selection = alt.selection_point(
        fields=["mutant", "site"], on="mouseover", empty=False
    )

    mut_scatter = (
        mut_base
        .transform_filter(site_selection)
        .transform_lookup(
            lookup='mutant',
            from_=alt.LookupData(data=aa_color_df, key='mutant', fields=['color']),
        )
        .transform_calculate(
            x=f"datum[{cell_1_selection.name}]",
            y=f"datum[{cell_2_selection.name}]",
            x_floored=f'isValid(datum.x) ? max(datum.x, {min_effect_slider.name}) : datum.x',
            y_floored=f'isValid(datum.y) ? max(datum.y, {min_effect_slider.name}) : datum.y',
        )
        .encode(
            alt.X("x_floored:Q", title="comparator cell line"),
            alt.Y("y_floored:Q", title="reference cell line"),
            alt.Text("mutant:N"),
            alt.Color("color:N", scale=None),
            size=alt.condition(mutant_selection, alt.value(22), alt.value(18)),
            strokeWidth=alt.condition(mutant_selection, alt.value(1), alt.value(0)),
            fillOpacity=alt.condition(mutant_selection, alt.value(1), alt.value(0.75)),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in cells]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=600)
        .add_params(cell_1_selection, cell_2_selection, mutant_selection)
        .properties(
            title=alt.TitleParams(
                alt.expr(f'"mutation effects at site " + {site_selection.name}.site')
            ),
            width=220,
            height=220,
        )
    )

    scatter_diagonal = (
        alt.Chart()
        .mark_rule(color="gray", strokeWidth=3, strokeDash=[6, 6], opacity=0.5)
        .transform_calculate(ax_lim=min_effect_slider.name)
        .encode(
            alt.X("ax_lim:Q", scale=alt.Scale(nice=False, padding=9, zero=False)),
            alt.Y("ax_lim:Q", scale=alt.Scale(nice=False, padding=9, zero=False)),
            x2=alt.datum(max_effect),
            y2=alt.datum(max_effect),
        )
    )

    scatter_chart = scatter_diagonal + mut_scatter

    # make the heatmaps
    assert all(mut_data["sequential_site"] == mut_data["sequential_site"].astype(int))
    assert all(site_diffs["sequential_site"] == site_diffs["sequential_site"].astype(int))
    
    mut_base = alt.Chart(mut_data).add_params(min_effect_slider)
    heatmap_base = (
        mut_base
        .transform_filter(
            f"abs(datum.sequential_site - {sequential_site_selection.name}.sequential_site) <= {heatmap_flank}"
        )
        .encode(
            alt.X("site", sort=alt.SortField("sequential_site")),
            alt.Y("mutant", sort=aas),
        )
        .properties(width=alt.Step(12), height=alt.Step(12))
    )

    # gray background for missing values
    heatmap_bg = heatmap_base.transform_impute(
        impute="_stat_dummy",
        key="mutant",
        keyvals=aas,
        groupby=["site"],
        value=None,
    ).mark_rect(color="#E0E0E0", opacity=0.8)

    # mark X for wildtype
    heatmap_wildtype = (
        heatmap_base
        .transform_filter(alt.datum["wildtype"] == alt.datum["mutant"])
        .mark_text(text="x", color="black")
    )

    # make heatmap for each cell type
    heatmaps = []
    for cell in cells:
        first_cell = (cell == list(cells)[0])
        heatmap_muts = (
            heatmap_base
            .transform_calculate(
                effect_floored=f'isValid(datum["{cell}"]) ? max(datum["{cell}"], {min_effect_slider.name}) : datum["{cell}"]'
            )
            .encode(
                alt.Y("mutant", sort=aas, title="amino acid" if first_cell else None),
                alt.Color(
                    "effect_floored:Q",
                    title="mutation effect",
                    legend=alt.Legend(
                        orient="right", titleOrient="right", gradientStrokeColor="black", gradientStrokeWidth=1
                    ),
                    scale=alt.Scale(
                        scheme="redblue",
                        nice=False,
                        domainMid=0,
                        domainMax=max(mut_data[list(cells)].max().max(), heatmap_max_at_least),
                    ),
                ),
                strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(1)),
                tooltip=["site", "sequential_site", "wildtype", "mutant"] + [alt.Tooltip(c, format=".2f") for c in cells],
            )
            .mark_rect(stroke="black", opacity=1, strokeOpacity=1)
            .properties(title=f"{cell} effect")
        )
        heatmaps.append(heatmap_bg + heatmap_muts + heatmap_wildtype)

    heatmap = alt.hconcat(*heatmaps, spacing=7)

    # assemble the final chart
    chart = (
        alt.vconcat(
            alt.hconcat(alt.vconcat(site_chart, zoom_bar, spacing=4), scatter_chart),
            heatmap,
        )
        .configure_title(fontSize=18, subtitleFontSize=16)
        .configure_axis(grid=False, labelFontSize=11, titleFontSize=16)
        .configure_legend(labelFontSize=12, titleFontSize=16)
    )

    return chart
In [10]:
site_chart = plot_site_comparison(
    mut_data, site_diffs, cells, site_diff_metrics, aas, aa_color_df, floor_mut_effects, 2, 12
)

alt.renderers.set_embed_options(
    padding={"left": 5, "right": 5, "bottom": 5, "top": 5}
)

print(f"Saving to {site_zoom_chart=}")
site_chart.save(site_zoom_chart)
site_chart
Saving to site_zoom_chart='results/compare_cell_entry/compare_cell_entry_site_zoom.html'
Out[10]:

Make paper figure plots¶

These plots have some manually hardcoded variables, unlike the code above.

First, scatter plots of mean effect at each site in different cells:

In [11]:
fig_site_data = (
    mut_data
    .query("wildtype != mutant")
    .groupby(["wildtype", "site", "region"], as_index=False)
    .aggregate(
        **{
            cell: pd.NamedAgg(cell, lambda s: s.clip(lower=floor_mut_effects).mean())
            for cell in cells
        }
    )
)

fig_site_selection = alt.selection_point(fields=["site"], empty=False, on="mouseover")

fig_site_scatter_base = alt.Chart(fig_site_data).add_params(fig_site_selection)
fig_site_scatter_chart = []
for cell1, cell2 in itertools.combinations(reversed(cells), 2):
    fig_site_scatter_chart.append(
        fig_site_scatter_base
        .encode(
            alt.X(cell1, axis=alt.Axis(values=[0, -2, -4]), scale=alt.Scale(nice=False, padding=8)),
            alt.Y(cell2, axis=alt.Axis(values=[0, -2, -4]), scale=alt.Scale(nice=False, padding=8)),
            tooltip=["site", "wildtype"],
            fill=alt.condition(fig_site_selection, alt.value("red"), alt.value("gray")),
            fillOpacity=alt.condition(fig_site_selection, alt.value(1), alt.value(0.25)),
            size=alt.condition(fig_site_selection, alt.value(90), alt.value(35)),
        )
        .mark_circle(fill="gray", fillOpacity=0.25, strokeOpacity=0.7, stroke="black", strokeWidth=0.5)
        .properties(width=127, height=127)
    )
fig_site_scatter_chart = (
    alt.vconcat(*fig_site_scatter_chart, spacing=13)
    .properties(
        title=alt.TitleParams(
            ["average effect of", "mutations at each site"],
            anchor="middle",
            dx=13
        )
    )
)

fig_site_scatter_chart
Out[11]:

Combine the scatter plot with a bar plot of the mean difference at each site:

In [12]:
fig_site_diffs = (
    site_diffs
    .assign(comparison=lambda x: x["cell_1"] + " minus " + x["cell_2"])
    [["site", "sequential_site", "region", "comparison", "mean difference"]]
)

fig_site_width = 680

# bar chart of the mean difference at each site
fig_site_diffs_chart = (
    alt.Chart(fig_site_diffs)
    .add_params(fig_site_selection)
    .encode(
        alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(
                values=fig_site_diffs[["sequential_site", "site"]].drop_duplicates().sort_values("sequential_site")["site"].iloc[30::80],
                labelAngle=0,
            ),
        ),
        alt.Y("mean difference", title=None, scale=alt.Scale(nice=False, padding=4)),
        alt.Row(
            "comparison",
            title=None,
            header=alt.Header(labelFontSize=12),
            spacing=10,
            sort=list(reversed(fig_site_diffs["comparison"].unique())),
        ),
        color=alt.condition(fig_site_selection, alt.value("red"), alt.value("black")),
        tooltip=["site", alt.Tooltip("mean difference", format=".2f", title="difference")],
    )
    .mark_bar(width=2, opacity=1, strokeWidth=0)
    .properties(height=155, width=fig_site_width)
)

# region overlay for line chart
region_chart = (
    alt.Chart(fig_site_diffs[["sequential_site", "region"]].drop_duplicates())
    .encode(
        alt.X("sequential_site:O", axis=None),
        alt.Color(
            "region",
            legend=None,
            scale=alt.Scale(range=["AliceBlue", "CadetBlue", "CadetBlue", "AliceBlue"])
        ),
    )
    .mark_rect(opacity=0.75, strokeWidth=0)
    .properties(width=fig_site_width)
)

text_df = fig_site_diffs.groupby("region", as_index=False).aggregate(x=pd.NamedAgg("sequential_site", "mean"))
text_chart = (
    alt.Chart(text_df)
    .encode(
        alt.X(
            "x:Q",
            title=None,
            scale=alt.Scale(domain=(fig_site_diffs["sequential_site"].min(), fig_site_diffs["sequential_site"].max())),
            axis=None,
        ),
        alt.Text("region"),
    )
    .mark_text(fontWeight="bold", fontSize=13)
    .properties(width=fig_site_width, height=15)
)

overlay_chart = region_chart + text_chart

fig_site_line_chart = (
    alt.vconcat(overlay_chart, fig_site_diffs_chart, spacing=0)
    .properties(
        title=alt.TitleParams("average difference in mutation effects at each site", anchor="middle"),
    )
)

fig_site_chart = alt.hconcat(
    fig_site_scatter_chart,
    fig_site_line_chart,
    spacing=55,
    center=True,
)

fig_site_chart.configure_axis(grid=False, titleFontSize=12, titleFontWeight="normal").configure_view(stroke="black")
Out[12]:

Now make a function to plot the mutation effects in different cells at key sites:

In [13]:
def plot_site_scatter(site, cell_1, cell_2, no_y_axis=None, ax_max=1, bold_letters=None):
    assert site in set(mut_data["site"]), site
    mut_df = (
        mut_data
        [mut_data["site"] == site]
        [["wildtype", "mutant", cell_1, cell_2]]
    )
    for cell in [cell_1, cell_2]:
        mut_df[cell] = mut_df[cell].clip(lower=floor_mut_effects)

    ax_min = floor_mut_effects - 0.5
    if ax_max is None:
        ax_max = mut_data[[cell_1, cell_2]].max().max() + 0.5

    if bold_letters is None:
        mut_df = mut_df.merge(aa_color_df, on="mutant", validate="one_to_one").assign(
            opacity=0.75, strokeWidth=0.4, size=14
        )
    else:
        mut_df["color"] = mut_df.apply(
            lambda r: (
                ("red" if r["mutant"] != r["wildtype"] else "black") if r["mutant"] in bold_letters else "darkblue"
            ),
            axis=1,
        )
        mut_df["opacity"] = mut_df["mutant"].map(lambda a: 1 if a in bold_letters else 0.25)
        mut_df["strokeWidth"] = mut_df["mutant"].map(lambda a: 0.5 if a in bold_letters else 0)
        mut_df["size"] = mut_df["mutant"].map(lambda a: 15 if a in bold_letters else 12)

    mut_scatter = (
        alt.Chart(mut_df)
        .encode(
            alt.X(cell_1, scale=alt.Scale(domain=(ax_min, ax_max), nice=False)),
            alt.Y(
                cell_2,
                title=None if no_y_axis else cell_2,
                scale=alt.Scale(domain=(ax_min, ax_max), nice=False),
                axis=alt.Axis(labels=False) if no_y_axis else alt.Axis()
            ),
            alt.Text("mutant"),
            alt.Color("color:N", scale=None),
            alt.FillOpacity("opacity", scale=None),
            alt.StrokeWidth("strokeWidth", scale=None),
            alt.Size("size", scale=None),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in [cell_1, cell_2]]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=700)
        .properties(
            title=alt.TitleParams(f"site {site}", fontSize=12),
            width=100,
            height=100,
        )
    )

    scatter_diagonal = (
        alt.Chart()
        .mark_rule(color="gray", strokeWidth=3, strokeDash=[6, 6], opacity=0.4)
        .encode(
            x=alt.datum(ax_min),
            y=alt.datum(ax_min),
            x2=alt.datum(ax_max),
            y2=alt.datum(ax_max),
        )
    )

    return scatter_diagonal + mut_scatter
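
The flooring step in `plot_site_scatter` can be illustrated on its own. This is a minimal sketch with hypothetical effect values and hypothetical column names (`cell_x`, `cell_y`); only `floor_mut_effects = -5` is taken from the parameters cell.

```python
import pandas as pd

floor_mut_effects = -5  # same floor as in the parameters cell

# Hypothetical effects of three mutations at one site in two cell lines.
toy_mut_df = pd.DataFrame(
    {
        "mutant": ["A", "D", "K"],
        "cell_x": [0.2, -6.3, -1.5],
        "cell_y": [-0.1, -8.0, -5.4],
    }
)

# Clip each cell's effects from below, as done inside `plot_site_scatter`.
for cell in ["cell_x", "cell_y"]:
    toy_mut_df[cell] = toy_mut_df[cell].clip(lower=floor_mut_effects)

floored = toy_mut_df.set_index("mutant").to_dict("index")
print(floored)
```

After flooring, mutation `D` sits at (-5, -5) on the diagonal even though its raw values differ by 1.7, so differences among strongly deleterious mutations are not displayed as cell-specific effects.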

Plot the mutation effects at the sites with the largest differences for each pair of cell lines, and merge the plots into one figure:

In [14]:
top_n = 4  # number of top sites to plot for each pair

top_diff_sites = (
    site_diffs
    .assign(abs_diff=lambda x: x["mean difference"].abs())
    .sort_values("abs_diff", ascending=False)
    .groupby(["cell_2", "cell_1"])
    .aggregate(sites=pd.NamedAgg("site", lambda s: s.iloc[:top_n].tolist()))
    ["sites"].to_dict()
)

top_diff_sites_scatter_chart = []
for cell1, cell2 in itertools.combinations(reversed(cells), 2):
    pair_top_sites = top_diff_sites[(cell1, cell2)]
    top_diff_sites_scatter_chart.append(
        alt.hconcat(
            *[plot_site_scatter(s, cell1, cell2, no_y_axis=(s != pair_top_sites[0])) for s in pair_top_sites],
            spacing=2,
        )
    )
top_diff_sites_scatter_chart = alt.vconcat(
    *top_diff_sites_scatter_chart, spacing=15
)

(
    alt.vconcat(fig_site_chart, top_diff_sites_scatter_chart, spacing=35)
    .configure_axis(grid=False, titleFontWeight="normal", titleFontSize=12)
    .configure_view(stroke="black")
)
Out[14]:
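
The selection of top sites per cell-line pair can be checked on a small table. The sketch below follows the same sort-then-group pattern as the cell above, but uses hypothetical cell names (`x`, `y`, `z`), hypothetical site labels, and made-up differences; it also uses `head(top_n)` rather than slicing, which is an explicitly positional equivalent.

```python
import pandas as pd

top_n = 2  # smaller than in the notebook, to keep the toy example short

# Hypothetical site differences for two cell-line pairs.
toy_site_diffs = pd.DataFrame(
    {
        "cell_1": ["x", "x", "x", "y", "y", "y"],
        "cell_2": ["z", "z", "z", "z", "z", "z"],
        "site": ["10", "11", "12", "10", "11", "12"],
        "mean difference": [0.5, -2.0, 1.0, -0.1, 0.3, -3.0],
    }
)

toy_top_sites = (
    toy_site_diffs
    # rank by magnitude so large negative and positive differences both count
    .assign(abs_diff=lambda x: x["mean difference"].abs())
    .sort_values("abs_diff", ascending=False)
    # groupby keeps the sorted row order within each group
    .groupby(["cell_2", "cell_1"])
    .aggregate(sites=pd.NamedAgg("site", lambda s: s.head(top_n).tolist()))
    ["sites"]
    .to_dict()
)
print(toy_top_sites)
```

The resulting dictionary is keyed by `(cell_2, cell_1)` tuples, so any lookup has to supply the pair in that order.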

Make a figure showing the mutations selected for validation, with the wildtype in black, the mutants that were made in red, and all other letters in faint blue:

In [15]:
# experimentally validated sites, each mapped to the letters to highlight in the plots
validation_sites = {
    "119(E2)": ["R", "K"],
    "120(E2)": ["K", "D"],
    "121(E2)": ["I", "E"],
    "157(E2)": ["A", "S"],
    "158(E2)": ["Q", "T", "V"],
}

validation_scatter_chart = []
for cell1, cell2 in itertools.combinations(reversed(cells), 2):
    validation_scatter_chart.append(
        alt.hconcat(
            *[
                plot_site_scatter(
                    s,
                    cell1,
                    cell2,
                    no_y_axis=(s != list(validation_sites)[0]),
                    ax_max=mut_data.query("site in @validation_sites")[[cell1, cell2]].max().max() + 0.6,
                    bold_letters=letters,
                )
                for s, letters in validation_sites.items()
            ],
            spacing=2,
        )
    )
validation_scatter_chart = (
    alt.vconcat(*validation_scatter_chart, spacing=15)
    .configure_axis(grid=False, titleFontWeight="normal", titleFontSize=12)
    .configure_view(stroke="black")
)
validation_scatter_chart
Out[15]:
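
The coloring rule applied when `bold_letters` is passed can be sketched separately from the plotting code. The site below is hypothetical (wildtype `R`, highlighted letters `R` and `K`), and `letter_color` is an illustrative helper that restates the lambda in `plot_site_scatter`; it is not a function defined in the notebook.

```python
import pandas as pd

# Hypothetical site with wildtype R, where R and K are the highlighted letters.
toy_df = pd.DataFrame({"wildtype": ["R"] * 4, "mutant": ["R", "K", "A", "D"]})
bold_letters = ["R", "K"]


def letter_color(row):
    """Return black for a highlighted wildtype, red for a highlighted mutant, else dark blue."""
    if row["mutant"] not in bold_letters:
        return "darkblue"
    return "black" if row["mutant"] == row["wildtype"] else "red"


toy_df["color"] = toy_df.apply(letter_color, axis=1)
# Highlighted letters are fully opaque; all others are faint.
toy_df["opacity"] = toy_df["mutant"].map(lambda a: 1 if a in bold_letters else 0.25)

colors = dict(zip(toy_df["mutant"], toy_df["color"]))
print(colors)
```

Note that a wildtype letter is drawn in black only if it is included in `bold_letters`; otherwise it falls into the faint dark-blue group with the other letters.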